import sys
from numpy import fliplr
from sklearn.model_selection import train_test_split
sys.path.insert(0,'.')
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.models import resnet
import pytorch_to_caffe

import os
from PIL import Image
import torchvision.transforms as transforms
from ResNet_LSTM import Res_LSTM # ResNet + LSTM部分
from net_linear_LSTM import NaiveLSTM # 使用PyTorch搭建的单层LSTM
from net_last_fc import FC_Net # 最后的fc


# Original example: convert a torchvision ResNet-18 with pytorch_to_caffe
# if __name__=='__main__':
#     name='resnet18'
#     resnet18=resnet.resnet18()
#     checkpoint = torch.load("/home/shining/Downloads/resnet18-5c106cde.pth")

#     resnet18.load_state_dict(checkpoint)
#     resnet18.eval()
#     input=torch.ones([1,3,224,224])
#      #input=torch.ones([1,3,224,224])
#     pytorch_to_caffe.trans_net(resnet18,input,name)
#     pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
#     pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))

def read_images(frames, transform, folder_path, start=1, filename_fmt='{:03d}.jpg'):
    """Read ``frames`` evenly spaced frame images from a folder and stack them.

    The frame read at position ``i`` is ``filename_fmt.format(start + i * step)``
    inside ``folder_path``. Here ``step = len(os.listdir(folder_path)) // frames``.

    Args:
        frames: Number of frames to sample. Must be positive.
        transform: Optional callable applied to each PIL image. The results are
            combined with ``torch.stack``, so in practice it must return tensors
            (e.g. a torchvision ``ToTensor`` pipeline).
        folder_path: Directory holding the frame images.
        start: Index of the first frame to read. Defaults to 1 to skip frame 0,
            which is usually an all-black image.
        filename_fmt: ``str.format`` pattern mapping a frame index to a file name.
            Defaults to 3-digit zero padding (``001.jpg``).

    Returns:
        The stacked frames with the first two axes swapped (``permute(1, 0, 2, 3)``).
        If ``transform`` yields (C, H, W) tensors, the result is (C, T, H, W),
        which is the layout a 3D CNN consumes.

    Raises:
        ValueError: If ``frames`` is not positive, or the folder holds fewer
            entries than ``frames``. In the second case the sampling step would
            be 0 and every sample would silently be the same frame.
    """
    if frames <= 0:
        raise ValueError("frames must be positive, got {}".format(frames))

    # NOTE(review): this counts every entry in the folder, not only image files.
    # Confirm the folder contains only frames; extra files change the step.
    num_entries = len(os.listdir(folder_path))
    step = num_entries // frames
    if step < 1:
        raise ValueError(
            "Too few images in your data folder: {} ({} entries, {} frames requested)".format(
                folder_path, num_entries, frames))

    images = []
    for i in range(frames):
        frame_path = os.path.join(folder_path, filename_fmt.format(start + i * step))
        # copy() loads the pixels while the file is open; the `with` then
        # releases the file handle right away instead of leaving it to the GC.
        with Image.open(frame_path) as img:  # .convert('L') for grayscale
            image = img.copy()
        if transform is not None:
            image = transform(image)
        images.append(image)

    stacked = torch.stack(images, dim=0)
    # Swap the frame axis with the channel axis for the 3D CNN.
    return stacked.permute(1, 0, 2, 3)



# # Convert the ResNet-50 part of the model
# sample_size = 256
# sample_duration = 16
# num_classes = 101
# arch = "resnet50" # ResNet网络选择
# if __name__=='__main__':
#     name='res50_lstm-res'
#     # net = resnet.resnet50()
#     net = Res_LSTM(sample_size=sample_size, sample_duration=sample_duration, num_classes=num_classes, arch=arch) # 自己搭建的网络

#     # path = '/home/pytorch_to_caffe_master/res50lstm_epoch026.pth' # !!!!记得修改权重路径
#     path = r'D:\Works\Hisi_codes\SLR-master\res50lstm_models\res50lstm_epoch026-res.pth' # !!!!记得修改权重路径
    
#     checkpoint = torch.load(path) # !!!!记得修改权重路径
#     net.resnet.load_state_dict(checkpoint)
    
#     # # print(list(net.children()))
#     # modules = list(net.children())[:-2] # :-2  去除resnet后的lstm和全连接 
#     # # print("modules:\n", modules)
#     # net = nn.Sequential(*modules)

#     net.eval()
#     # input=torch.ones([16, 3, 256, 256])  # 4-D: (16, channel, h, w); 16 presumably = sample_duration frames fed as the batch
#     # input=torch.ones([1, 3, 16, 256, 256])  # (batch_size=1, channel, t, h, w)
#     input=torch.ones([1, 3, 256, 256])  # (batch_size=1, channel, h, w)
#     print("type(input):   ", type(input))
#     print(input.shape)

#     pytorch_to_caffe.trans_net(net.resnet, input, name)
#     pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
#     pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))

# Convert the hand-built LSTM part of the model.
if __name__ == '__main__':
    name = 'net_linear_lstm2'
    input_size = 2048  # LSTM input feature size
    hidden_size = 512  # LSTM hidden-state size
    net = NaiveLSTM(input_size=input_size, hidden_size=hidden_size)  # hand-built network

    # Weight path: pass it as the first command-line argument, otherwise the
    # default below is used (!!!! remember to update it for your machine).
    default_path = r'D:\Works\Hisi_codes\SLR-master\res50lstm_models\net_linear_lstm.pth'
    path = sys.argv[1] if len(sys.argv) > 1 else default_path

    # The network and the dummy inputs below live on the CPU. map_location='cpu'
    # lets a checkpoint that was saved from GPU tensors load on a CPU-only machine.
    checkpoint = torch.load(path, map_location='cpu')
    net.load_state_dict(checkpoint)

    net.eval()

    # Dummy tensors used only to trace the network. The input is presumably
    # (seq_len=1, batch=1, input_size) and h0/c0 are (1, 1, hidden_size);
    # confirm against NaiveLSTM.forward.
    dummy_input = torch.ones(1, 1, input_size)
    h0 = torch.ones(1, 1, hidden_size)
    c0 = torch.ones(1, 1, hidden_size)
    print("type(input):   ", type(dummy_input))
    print(dummy_input.shape)

    pytorch_to_caffe.trans_net_mylstm(net, dummy_input, (h0, c0), name)
    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))

# # Convert the final fully connected layer
# if __name__=='__main__':
#     name='net_last_fc'
#     in_size = 512  # 输入尺寸
#     out_size= 101  # 输出尺寸
#     net = FC_Net(in_size=in_size, out_size=out_size) # 创建网络

#     path = r'D:\Works\Hisi_codes\SLR-master\res50lstm_models\net_last_fc.pth' # !!!!记得修改权重路径
    
#     checkpoint = torch.load(path) # !!!!记得修改权重路径
#     net.load_state_dict(checkpoint)

#     net.eval()

#     input = torch.ones(1, in_size)
#     out = net(input)
#     print("out.shape:", out.shape)
#     # print("out:\n", out)

#     pytorch_to_caffe.trans_net(net, input, name)
#     pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
#     pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))
    